{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2087a56",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install into the running kernel's environment (%pip targets it, !pip may not).\n",
    "# gliner is included because a later cell imports GLiNER from it.\n",
    "%pip install datasets gliner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "24a58336-3b16-491b-8646-2f54e93a8964",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "df670cb5-24cb-4683-849e-2e27769dd762",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ner_tags_to_spans(samples, tag_to_id):\n",
    "    \"\"\"\n",
    "    Convert one sample's token-level NER tag IDs into inclusive entity spans.\n",
    "\n",
    "    Args:\n",
    "        samples (dict): A single example holding the token sequence under\n",
    "            'tokens' and the per-token tag IDs under 'ner_tags'.\n",
    "        tag_to_id (dict): Maps tag names ('O', 'B-<type>', 'I-<type>') to the\n",
    "            IDs used in samples['ner_tags']. The ID of 'O' is looked up here;\n",
    "            if 'O' is absent, ID 0 is treated as the outside tag.\n",
    "\n",
    "    Returns:\n",
    "        dict: {'tokenized_text': tokens, 'ner': spans}, where each span is a\n",
    "        (start, end, entity_type) tuple with inclusive token indices.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If a non-outside tag ID is missing from tag_to_id.\n",
    "    \"\"\"\n",
    "    tokens = samples[\"tokens\"]\n",
    "    ner_tags = samples[\"ner_tags\"]\n",
    "    id_to_tag = {v: k for k, v in tag_to_id.items()}\n",
    "    # Do not assume the outside tag is ID 0; fall back to 0 only if 'O' is unmapped.\n",
    "    outside_id = tag_to_id.get('O', 0)\n",
    "    spans = []\n",
    "    start_pos = None\n",
    "    entity_name = None\n",
    "\n",
    "    for i, tag in enumerate(ner_tags):\n",
    "        tag_name = 'O' if tag == outside_id else id_to_tag[tag]\n",
    "\n",
    "        # An I- tag of the same type simply extends the open entity.\n",
    "        if tag_name.startswith('I-') and tag_name[2:] == entity_name:\n",
    "            continue\n",
    "\n",
    "        # Every other tag ends whatever entity is currently open.\n",
    "        if entity_name is not None:\n",
    "            spans.append((start_pos, i - 1, entity_name))\n",
    "            entity_name = None\n",
    "            start_pos = None\n",
    "\n",
    "        # B- opens a new entity. An I- that reaches this point has no matching\n",
    "        # open entity (orphan or type switch), so it opens one as well instead\n",
    "        # of being dropped or merged into the previous entity.\n",
    "        if tag_name.startswith(('B-', 'I-')):\n",
    "            entity_name = tag_name[2:]\n",
    "            start_pos = i\n",
    "\n",
    "    # Close the entity that runs through the final tag, if any.\n",
    "    if entity_name is not None:\n",
    "        spans.append((start_pos, len(ner_tags) - 1, entity_name))\n",
    "\n",
    "    return {\"tokenized_text\": tokens, \"ner\": spans}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "971f92b9-ece2-460d-99f8-73277b5d3081",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: load the dataset; its 'train' split is converted to spans further down\n",
    "dataset = load_dataset(path=\"eriktks/conll2003\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "67a18f87-1571-4e8c-8253-6e0305bfa0cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 2: map each BIO tag name to the integer ID found in 'ner_tags'.\n",
    "# NOTE(review): the IDs presumably follow the dataset's own label order\n",
    "# (O, B-PER, I-PER, B-ORG, I-ORG, B-LOC, I-LOC, B-MISC, I-MISC) -- confirm\n",
    "# against the dataset's feature definition.\n",
    "tag_to_id = {\n",
    "    \"O\": 0,\n",
    "    \"B-person\": 1,\n",
    "    \"I-person\": 2,\n",
    "    \"B-organization\": 3,\n",
    "    \"I-organization\": 4,\n",
    "    \"B-location\": 5,\n",
    "    \"I-location\": 6,\n",
    "    \"B-others\": 7,\n",
    "    \"I-others\": 8,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "354aae86-2e5f-4a82-821b-6baba9438532",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert NER tags to spans for the training data\n",
    "gliner_data_conll = [\n",
    "    ner_tags_to_spans(sample, tag_to_id) for sample in dataset[\"train\"]\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "7c717148-7c98-4998-90fc-c244e11d7b67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pre-trained GLiNER checkpoint, then move it to the chosen device\n",
    "from gliner import GLiNER\n",
    "import torch\n",
    "\n",
    "# load_tokenizer=True: true if a model was trained from scratch with new code base\n",
    "model = GLiNER.from_pretrained(\"urchade/gliner_small\", load_tokenizer=True)\n",
    "\n",
    "# Use the GPU when CUDA is available, otherwise stay on the CPU\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "601c2e03-2fe7-481f-b769-c8c874bee9c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the model on the first 100 converted samples, restricted to the four\n",
    "# entity types used in tag_to_id\n",
    "evaluation_results = model.evaluate(\n",
    "    gliner_data_conll[:100],\n",
    "    flat_ner=True,\n",
    "    entity_types=[\"person\", \"organization\", \"location\", \"others\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "273d79be-2f0f-4191-bc8e-641854ffa540",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('P: 63.13%\\tR: 71.43%\\tF1: 67.02%\\n', 0.6702412868632708)\n"
     ]
    }
   ],
   "source": [
    "# evaluation_results is a (report string, F1 score) pair. Printing the tuple\n",
    "# itself shows its repr with literal \\t and \\n escapes, so print each part.\n",
    "report, f1_score = evaluation_results\n",
    "print(report.rstrip())\n",
    "print(f\"F1 (raw): {f1_score:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8c6190d-6a63-43c0-9010-27d141970877",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
